Comparing ESM-based models and RNASamba models for predicting coding and noncoding transcripts¶

Keith Cheveralls
March 2024

This notebook documents the visualizations that were used to compare the performance of ESM-based models and RNASamba models trained to predict whether transcripts are coding or noncoding. This work was motivated by developing an approach that uses ESM embeddings to identify sORFs for the peptigate pipeline.

The predictions from ESM-based models and RNASamba models on which this notebook depends were generated outside of this notebook. Predictions from ESM-based models were generated using the commands namespaced under the plmutils orf-prediction CLI. Predictions from RNASamba models were generated using the script found in the scripts/rnasamba-comparison subdirectory of this repo. The CLI commands that were used are briefly documented in the sections below.

In [1]:
import io
import pathlib
import pandas as pd
import seaborn as sns
import numpy as np
from Bio import SeqIO
import matplotlib.pyplot as plt

from plmutils.models import calc_metrics

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

Dataset metadata¶

The metadata associated with the 16 species used for these comparisons is included below for completeness. Note that the plots in this notebook label species using the species_id defined in this metadata (rather than the full species name).

In [2]:
metadata_csv_content = """
species_id	species_common_name	root_url	genome_name	cdna_endpoint	ncrna_endpoint	genome_abbreviation
hsap	human	https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/	Homo_sapiens.GRCh38	cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz	ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz	GRCh38
scer	yeast	https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/saccharomyces_cerevisiae/	Saccharomyces_cerevisiae.R64-1-1	cdna/Saccharomyces_cerevisiae.R64-1-1.cdna.all.fa.gz	ncrna/Saccharomyces_cerevisiae.R64-1-1.ncrna.fa.gz	R64-1-1
cele	worm	https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/caenorhabditis_elegans/	Caenorhabditis_elegans.WBcel235	cdna/Caenorhabditis_elegans.WBcel235.cdna.all.fa.gz	ncrna/Caenorhabditis_elegans.WBcel235.ncrna.fa.gz	WBcel235
atha	arabadopsis	https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/arabidopsis_thaliana/	Arabidopsis_thaliana.TAIR10	cdna/Arabidopsis_thaliana.TAIR10.cdna.all.fa.gz	ncrna/Arabidopsis_thaliana.TAIR10.ncrna.fa.gz	TAIR10
dmel	drosophila	https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/drosophila_melanogaster/	Drosophila_melanogaster.BDGP6.46	cdna/Drosophila_melanogaster.BDGP6.46.cdna.all.fa.gz	ncrna/Drosophila_melanogaster.BDGP6.46.ncrna.fa.gz	BDGP6.46
ddis	dictyostelium_discoideum	https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/dictyostelium_discoideum/	Dictyostelium_discoideum.dicty_2.7	cdna/Dictyostelium_discoideum.dicty_2.7.cdna.all.fa.gz	ncrna/Dictyostelium_discoideum.dicty_2.7.ncrna.fa.gz	dicty_2.7
mmus	mouse	https://ftp.ensembl.org/pub/release-111/fasta/mus_musculus/	Mus_musculus.GRCm39	cdna/Mus_musculus.GRCm39.cdna.all.fa.gz	ncrna/Mus_musculus.GRCm39.ncrna.fa.gz	GRCm39
drer	zebrafish	https://ftp.ensembl.org/pub/release-111/fasta/danio_rerio/	Danio_rerio.GRCz11	cdna/Danio_rerio.GRCz11.cdna.all.fa.gz	ncrna/Danio_rerio.GRCz11.ncrna.fa.gz	GRCz11
ggal	chicken	https://ftp.ensembl.org/pub/release-111/fasta/gallus_gallus/	Gallus_gallus.bGalGal1.mat.broiler.GRCg7b	cdna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.cdna.all.fa.gz	ncrna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.ncrna.fa.gz	bGalGal1.mat.broiler.GRCg7b
oind	rice	https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/oryza_indica/	Oryza_indica.ASM465v1	cdna/Oryza_indica.ASM465v1.cdna.all.fa.gz	ncrna/Oryza_indica.ASM465v1.ncrna.fa.gz	ASM465v1
zmay	maize	https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/zea_mays/	Zea_mays.Zm-B73-REFERENCE-NAM-5.0	cdna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.cdna.all.fa.gz	ncrna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.ncrna.fa.gz	Zm-B73-REFERENCE-NAM-5.0
xtro	frog	https://ftp.ensembl.org/pub/release-111/fasta/xenopus_tropicalis/	Xenopus_tropicalis.UCB_Xtro_10.0	cdna/Xenopus_tropicalis.UCB_Xtro_10.0.cdna.all.fa.gz	ncrna/Xenopus_tropicalis.UCB_Xtro_10.0.ncrna.fa.gz	UCB_Xtro_10.0
rnor	rat	https://ftp.ensembl.org/pub/release-111/fasta/rattus_norvegicus/	Rattus_norvegicus.mRatBN7.2	cdna/Rattus_norvegicus.mRatBN7.2.cdna.all.fa.gz	ncrna/Rattus_norvegicus.mRatBN7.2.ncrna.fa.gz	mRatBN7
amel	honeybee	https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/apis_mellifera/	Apis_mellifera.Amel_HAv3.1	cdna/Apis_mellifera.Amel_HAv3.1.cdna.all.fa.gz	ncrna/Apis_mellifera.Amel_HAv3.1.ncrna.fa.gz	Amel_HAv3.1
spom	fission_yeast	https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/schizosaccharomyces_pombe/	Schizosaccharomyces_pombe.ASM294v2	cdna/Schizosaccharomyces_pombe.ASM294v2.cdna.all.fa.gz	ncrna/Schizosaccharomyces_pombe.ASM294v2.ncrna.fa.gz	ASM294v2
tthe	tetrahymena	https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/tetrahymena_thermophila/	Tetrahymena_thermophila.JCVI-TTA1-2.2	cdna/Tetrahymena_thermophila.JCVI-TTA1-2.2.cdna.all.fa.gz	ncrna/Tetrahymena_thermophila.JCVI-TTA1-2.2.ncrna.fa.gz	JCVI-TTA1-2.2
"""

metadata = pd.read_csv(io.StringIO(metadata_csv_content), sep='\t')
metadata.head()
Out[2]:
species_id species_common_name root_url genome_name cdna_endpoint ncrna_endpoint genome_abbreviation
0 hsap human https://ftp.ensembl.org/pub/release-111/fasta/... Homo_sapiens.GRCh38 cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz GRCh38
1 scer yeast https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi... Saccharomyces_cerevisiae.R64-1-1 cdna/Saccharomyces_cerevisiae.R64-1-1.cdna.all... ncrna/Saccharomyces_cerevisiae.R64-1-1.ncrna.f... R64-1-1
2 cele worm https://ftp.ensemblgenomes.ebi.ac.uk/pub/metaz... Caenorhabditis_elegans.WBcel235 cdna/Caenorhabditis_elegans.WBcel235.cdna.all.... ncrna/Caenorhabditis_elegans.WBcel235.ncrna.fa.gz WBcel235
3 atha arabadopsis https://ftp.ensemblgenomes.ebi.ac.uk/pub/plant... Arabidopsis_thaliana.TAIR10 cdna/Arabidopsis_thaliana.TAIR10.cdna.all.fa.gz ncrna/Arabidopsis_thaliana.TAIR10.ncrna.fa.gz TAIR10
4 dmel drosophila https://ftp.ensemblgenomes.ebi.ac.uk/pub/metaz... Drosophila_melanogaster.BDGP6.46 cdna/Drosophila_melanogaster.BDGP6.46.cdna.all... ncrna/Drosophila_melanogaster.BDGP6.46.ncrna.f... BDGP6.46

Heatmap plotting functions¶

These are functions used later in the notebook to generate heatmap visualizations of the matrices of model performance metrics for all pairs of training and test species.

In [3]:
def plot_heatmap(df, column='accuracy', model_name='unknown', ax=None, **heatmap_kwargs):
    """
    Plot the values in the given column as a square heatmap of training vs test species
    (with training species on the x-axis and test species on the y-axis).

    Note: "training species" is the species used to train the model and "test species"
    is the species used to test each trained model.
    """
    df = df.pivot(index='test_species_id', columns='training_species_id', values=column)

    if ax is None:
        plt.figure(figsize=(8, 6))
        ax = plt.gca()

    sns.heatmap(
        df,
        cmap="coolwarm",
        annot=True,
        annot_kws={"size": 6},
        fmt=".1f",
        square=True,
        ax=ax,
        **heatmap_kwargs
    )

    name = column.replace('_', ' ')
    if name.lower() == 'mcc':
        name = name.upper()
    else:
        name = name[0].upper() + name[1:]

    ax.set_xlabel('Training species')
    ax.set_ylabel('Test species')
    ax.set_title(f'{name} | {model_name}')
In [4]:
def plot_heatmaps(df_left, df_right, column, model_names):
    """
    Plot a row of three heatmaps: one for the left dataframe, one for the right dataframe,
    and the third (the rightmost) for the difference between the two (right minus left).
    """
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(18, 5))

    df_merged = pd.merge(df_left, df_right, on=('training_species_id', 'test_species_id'))
    df_merged[column] = df_merged[f'{column}_y'] - df_merged[f'{column}_x']

    plot_heatmap(df_left, column=column, model_name=model_names[0], ax=axs[0])
    plot_heatmap(df_right, column=column, model_name=model_names[1], ax=axs[1])
    plot_heatmap(df_merged, column=column, model_name='difference', ax=axs[2], vmin=-1, vmax=1)

ESM-based model predictions¶

These predictions were generated using the plmutils orf-prediction CLI.

First, download the Ensembl datasets listed in the user-provided metadata CSV file (see above for the file used with this notebook):

plmutils orf-prediction download-data \
    output/data/ensembl-dataset-metadata.tsv \
    output/data/

Next, construct deduplicated sets of coding and noncoding transcripts. Deduplication is achieved by clustering transcripts by sequence identity and retaining only one representative sequence from each cluster.

plmutils orf-prediction construct-data \
    output/data/ensembl-dataset-metadata.tsv \
    output/data/ \
    --subsample-factor 1
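The clustering-based deduplication can be sketched as follows. This is an illustrative sketch, not the construct-data implementation: it assumes a headerless two-column cluster-assignment TSV (representative_id, member_id) of the kind produced by tools like mmseqs easy-cluster, and keeps only each cluster's representative sequence.

```python
import pandas as pd


def deduplicate_fasta(fasta_path, clusters_tsv_path, output_path):
    """Write one representative sequence per cluster.

    Assumes a headerless two-column TSV mapping representative_id to
    member_id (a hypothetical layout; the actual construct-data
    implementation may use a different tool and file format).
    """
    clusters = pd.read_csv(
        clusters_tsv_path, sep="\t", header=None,
        names=["representative_id", "member_id"],
    )
    keep = set(clusters.representative_id)

    n_written = 0
    with open(fasta_path) as infile, open(output_path, "w") as outfile:
        writing = False
        for line in infile:
            if line.startswith(">"):
                # fasta record ids are the first whitespace-delimited token.
                record_id = line[1:].split()[0]
                writing = record_id in keep
                n_written += writing
            if writing:
                outfile.write(line)
    return n_written
```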

Next, find the putative ORFs in the coding and noncoding transcripts, retain only the longest putative ORF from each transcript, and generate an embedding of the protein sequence it encodes:

plmutils orf-prediction translate-and-embed \
    output/data/processed/final/coding-dedup-ssx1/transcripts

plmutils orf-prediction translate-and-embed \
    output/data/processed/final/noncoding-dedup-ssx1/transcripts
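The "longest putative ORF" step can be illustrated with a minimal sketch (forward strand only; the actual plmutils implementation may scan both strands and handle edge cases such as ORFs without an in-frame stop differently):

```python
STOP_CODONS = {"TAA", "TAG", "TGA"}


def longest_orf(transcript: str) -> str:
    """Return the longest ATG-to-stop ORF (stop codon excluded) among
    the three forward reading frames. A simplified sketch of the idea,
    not the plmutils implementation.
    """
    seq = transcript.upper()
    best = ""
    for frame in range(3):
        codons = [seq[i:i + 3] for i in range(frame, len(seq) - 2, 3)]
        start = None
        for i, codon in enumerate(codons):
            if start is None:
                if codon == "ATG":
                    start = i
            elif codon in STOP_CODONS:
                orf = "".join(codons[start:i])
                if len(orf) > len(best):
                    best = orf
                start = None
    return best
```

The retained ORF would then be translated (e.g. with Bio.Seq) and embedded with the ESM model.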

Finally, train models using these embeddings to predict whether a given ORF originated from a coding or noncoding transcript. Separate models are trained on, and used to make predictions for, each species. This results in a matrix of model performance metrics for all pairs of species (one used to train the model, the other to evaluate it). The --output-dirpath in the command below corresponds to the directories passed to the calc_metrics_from_smallesm_results function defined below. (This command was run manually without and with --max-length 100 to train models on all ORFs and only sORFs, respectively.)

plmutils orf-prediction train-and-evaluate \
    --coding-dirpath output/data/processed/final/coding-dedup-ssx1/embeddings/esm2_t6_8M_UR50D \
    --noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/embeddings/esm2_t6_8M_UR50D \
    --output-dirpath output/data/esm-model-results-ssx1-all
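Conceptually, the train-and-evaluate step produces the all-pairs metrics matrix as in the following sketch, with synthetic data and a scikit-learn classifier standing in for the actual plmutils training code (the dataset shapes and classifier here are illustrative assumptions; the real esm2_t6_8M_UR50D embeddings are 320-dimensional):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# hypothetical per-species datasets of embeddings (X) and binary
# coding/noncoding labels (y).
datasets = {
    species_id: (rng.normal(size=(40, 8)), rng.integers(0, 2, size=40))
    for species_id in ("hsap", "scer", "cele")
}

# train one model per species and evaluate it on every species,
# yielding a (training species x test species) matrix of metrics.
accuracy = {}
for train_id, (X_train, y_train) in datasets.items():
    model = LogisticRegression().fit(X_train, y_train)
    for test_id, (X_test, y_test) in datasets.items():
        accuracy[(train_id, test_id)] = model.score(X_test, y_test)
```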
In [5]:
def calc_metrics_from_smallesm_results(results_dirpath, max_length=None):
    """
    Calculate classification metrics from ESM-based model results.
    """
    all_metrics = []
    prediction_filepaths = pathlib.Path(results_dirpath).glob('*.csv')
    for prediction_filepath in prediction_filepaths:
        df = pd.read_csv(prediction_filepath)

        if max_length is not None:
            df = df.loc[df.sequence_length < max_length]

        metrics = calc_metrics(
            y_true=(df.true_label == 'coding'),
            y_pred_proba=df.predicted_probability.values,
        )
        metrics['training_species_id'] = df.iloc[0].training_species_id
        metrics['test_species_id'] = df.iloc[0].testing_species_id
        metrics['num_coding'] = (df.true_label == 'coding').sum()
        metrics['num_noncoding'] = (df.true_label != 'coding').sum()

        all_metrics.append(metrics)
    df = pd.DataFrame(all_metrics)
    df['true_negative_rate'] = df.num_true_negative / df.num_noncoding
    return df
In [6]:
metrics_esm_trained_all_eval_all = calc_metrics_from_smallesm_results(
    '../output/results/2024-03-01-esm-model-results-ssx1-all/',
    max_length=None,
)
metrics_esm_trained_all_eval_short = calc_metrics_from_smallesm_results(
    '../output/results/2024-03-01-esm-model-results-ssx1-all/',
    max_length=100,
)
metrics_esm_trained_short_eval_all = calc_metrics_from_smallesm_results(
    '../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/',
    max_length=None,
)
metrics_esm_trained_short_eval_short = calc_metrics_from_smallesm_results(
    '../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/',
    max_length=100,
)
In [7]:
metrics_esm_trained_all_eval_all.head()
Out[7]:
auc_roc accuracy precision recall mcc num_true_positive num_false_positive num_true_negative num_false_negative num_positive num_negative training_species_id test_species_id num_coding num_noncoding true_negative_rate
0 0.803246 0.820781 0.818813 0.998590 0.246917 18411 4074 366 26 18437 4440 ddis rnor 18437 4440 0.082432
1 0.957956 0.949619 0.969926 0.970958 0.799962 15513 481 2299 464 15977 2780 amel dmel 15977 2780 0.826978
2 0.825174 0.859946 0.858987 0.999687 0.216374 15972 2622 158 5 15977 2780 oind dmel 15977 2780 0.056835
3 0.861818 0.500659 1.000000 0.500417 0.021989 11393 0 11 11374 22767 11 rnor tthe 22767 11 1.000000
4 0.986828 0.956949 0.972110 0.973813 0.867413 19449 558 4580 523 19972 5138 rnor cele 19972 5138 0.891397

Compare ESM-based models trained on all ORFs and only sORFs¶

In [8]:
# models trained on either all ORFs or only sORFs and evaluated on only sORFs.
plot_heatmaps(
    metrics_esm_trained_all_eval_short,
    metrics_esm_trained_short_eval_short,
    column='mcc',
    model_names=('ESM-based (trained all, eval short)', 'ESM-based (trained short, eval short)')
)
In [9]:
# models trained only on sORFs and evaluated on all or only sORFs.
plot_heatmaps(
    metrics_esm_trained_short_eval_all,
    metrics_esm_trained_short_eval_short,
    column='mcc',
    model_names=('ESM-based (trained short, eval all)', 'ESM-based (trained short, eval short)')
)
In [10]:
# models trained on all ORFs or only sORFs, but evaluated on all sequences.
plot_heatmaps(
    metrics_esm_trained_all_eval_all,
    metrics_esm_trained_short_eval_all,
    column='mcc',
    model_names=('ESM-based (trained all, eval all)', 'ESM-based (trained short, eval all)')
)

RNASamba predictions¶

These predictions were generated by the script scripts/rnasamba-comparison/train_and_evaluate.py, using the same datasets of deduplicated coding and noncoding transcripts generated by the plmutils orf-prediction construct-data command described above.

To train RNASamba models on all sequences:

python scripts/rnasamba-comparison/train_and_evaluate.py \
    --coding-dirpath output/data/processed/final/coding-dedup-ssx1/transcripts \
    --noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/transcripts \
    --output-dirpath 2024-02-28-rnasamba-results-ssx1-all

To train RNASamba models on transcripts corresponding to sORFs:

python scripts/rnasamba-comparison/train_and_evaluate.py \
    --coding-dirpath output/data/processed/final/coding-dedup-ssx1/transcripts \
    --noncoding-dirpath output/data/processed/final/noncoding-dedup-ssx1/transcripts \
    --output-dirpath output/data/2024-02-28-rnasamba-results-ssx1-max-peptide-length-100 \
    --max-length 100

The --output-dirpath above corresponds to the directory passed to the calc_metrics_from_rnasamba_results function below.

In [11]:
def calc_metrics_from_rnasamba_results(rnasamba_results_dirpath):
    """
    Aggregate the results from RNASamba models trained in the script 
    `scripts/rnasamba-comparison/train_and_evaluate.py`.
    """
    all_metrics = []
    dirpaths = [p for p in pathlib.Path(rnasamba_results_dirpath).glob('trained-on*') if p.is_dir()]
    for dirpath in dirpaths:

        # dirnames are of the form 'trained-on-{species_id}-filtered'.
        training_species_id = dirpath.stem.split('-')[2]

        prediction_filepaths = dirpath.glob('*.tsv')
        for prediction_filepath in prediction_filepaths:

            # filenames are of the form '{species_id}-preds.tsv'.
            test_species_id = prediction_filepath.stem.split('-')[0]

            df = pd.read_csv(prediction_filepath, sep=',')
            metrics = calc_metrics(
                y_true=(df.true_label == 'coding'), y_pred_proba=df.coding_score.values
            )
            metrics['training_species_id'] = training_species_id
            metrics['test_species_id'] = test_species_id
            metrics['num_coding'] = (df.true_label == 'coding').sum()
            metrics['num_noncoding'] = (df.true_label != 'coding').sum()

            all_metrics.append(metrics)

    df = pd.DataFrame(all_metrics)
    df['true_negative_rate'] = df.num_true_negative / df.num_noncoding
    return df
In [12]:
# models trained and tested on all transcripts.
rnasamba_results_dirpath_all = pathlib.Path(
    '../output/results/2024-02-23-rnasamba-models-clustered-ssx3/'
)

# models trained and tested only on transcripts whose longest ORFs are sORFs.
rnasamba_results_dirpath_short = pathlib.Path(
    '../output/results/2024-02-28-rnasamba-results-ssx1-max-peptide-length-100/'
)

metrics_rs_trained_all_eval_all = calc_metrics_from_rnasamba_results(rnasamba_results_dirpath_all)
metrics_rs_trained_short_eval_short = calc_metrics_from_rnasamba_results(rnasamba_results_dirpath_short)
/home/keith/miniforge3/envs/esm-py311-env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1497: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

Compare RNASamba models trained on all or only sORFs¶

In [13]:
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_rs_trained_short_eval_short,
    column='mcc',
    model_names=('RNASamba (all)', 'RNASamba (short)')
)

Compare RNASamba and ESM-based models¶

These are the most important plots in this notebook. They compare the performance of ESM-based models to RNASamba models by plotting the heatmap of performance metrics side by side.

Models trained and evaluated on all transcripts (for RNASamba) or ORFs (for ESM-based)¶

In [14]:
# overall performance (MCC metric)
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column='mcc',
    model_names=('RNASamba (all)', 'ESM-based (all)')
)
In [15]:
# recall (also the true positive rate, or num_true_positive / num_coding)
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column='recall',
    model_names=('RNASamba (all)', 'ESM-based (all)')
)
In [16]:
# the true negative rate.
plot_heatmaps(
    metrics_rs_trained_all_eval_all,
    metrics_esm_trained_all_eval_all,
    column='true_negative_rate',
    model_names=('RNASamba (all)', 'ESM-based (all)')
)

Models trained only on short sequences (< 100aa)¶

For RNASamba, this means the models were trained only on transcripts whose longest ORF was an sORF (less than 100aa long).

Note that the class imbalance in this case is severe: most species have few coding transcripts whose longest ORF is an sORF. This likely at least partly explains why the RNASamba models perform so poorly here, since we do not compensate for the class imbalance when training the RNASamba models (whereas we do compensate when training the ESM-based models).
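For reference, one common way to compensate for class imbalance is to reweight classes during training. Below is a sketch using scikit-learn's class_weight='balanced' on synthetic data; it is illustrative only and is not the plmutils training code (which may compensate differently, e.g. by subsampling):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(42)

# synthetic imbalanced data: ~5% positive ("coding") examples, with a
# modest amount of signal added to the positive class.
X = rng.normal(size=(1000, 8))
y = (rng.random(1000) < 0.05).astype(int)
X[y == 1] += 1.0

# class_weight='balanced' reweights examples inversely to class frequency,
# so the minority class is not ignored during training.
balanced = LogisticRegression(class_weight="balanced").fit(X, y)
unweighted = LogisticRegression().fit(X, y)

n_balanced = int(balanced.predict(X).sum())
n_unweighted = int(unweighted.predict(X).sum())
```

The balanced model typically predicts the minority (coding) class far more often than the unweighted one, trading precision for recall.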

In [17]:
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column='mcc',
    model_names=('RNASamba (short)', 'ESM-based (short)')
)
In [18]:
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column='recall',
    model_names=('RNASamba (short)', 'ESM-based (short)')
)
In [19]:
plot_heatmaps(
    metrics_rs_trained_short_eval_short,
    metrics_esm_trained_short_eval_short,
    column='true_negative_rate',
    model_names=('RNASamba (short)', 'ESM-based (short)')
)
In [ ]:

Aside: blasting against peptipedia¶

We were curious whether some of the false positives from the ESM-based models represent genuine sORFs from lncRNAs (which are annotated as noncoding). To examine this, we blasted all of the putative ORFs against peptipedia, then plotted the distributions of minimum e-values for the putative sORFs for which the ESM-based model made either true or false positive predictions. If the model correctly identifies genuine sORFs from lncRNAs, we would expect an enrichment of low e-values among the false positives.

The command plmutils orf-classification blast-peptipedia was used to generate the directory of blast results that are loaded and concatenated by the concat_blast_results function below.

In [20]:
def concat_smallesm_results(results_dirpath):
    """
    Load and concatenate the predictions from esm-based models.
    """
    dfs = []
    prediction_filepaths = pathlib.Path(results_dirpath).glob('*.csv')
    for prediction_filepath in prediction_filepaths:
        dfs.append(pd.read_csv(prediction_filepath))

    return pd.concat(dfs)
In [21]:
# predictions from models trained on all putative ORFs.
esm_trained_all_preds = concat_smallesm_results(
    '../output/results/2024-03-01-esm-model-results-ssx1-all/'
)
In [22]:
# predictions from models trained on short peptides (< 100aa).
esm_trained_short_preds = concat_smallesm_results(
    '../output/results/2024-02-29-esm-model-results-ssx1-max-length-100/'
)
In [23]:
esm_trained_all_preds.shape, esm_trained_short_preds.shape
Out[23]:
((7766064, 6), (7766064, 6))
In [24]:
esm_trained_short_preds.head()
Out[24]:
sequence_id sequence_length true_label predicted_probability training_species_id testing_species_id
0 RNOR.ENSRNOT00000105380.1 124 coding 0.306121 ddis rnor
1 RNOR.ENSRNOT00000094775.1 265 coding 0.436046 ddis rnor
2 RNOR.ENSRNOT00000119508.1 594 coding 0.868186 ddis rnor
3 RNOR.ENSRNOT00000094997.1 690 coding 0.568402 ddis rnor
4 RNOR.ENSRNOT00000119131.1 390 coding 0.739940 ddis rnor
In [25]:
# count the number of peptides from coding and noncoding transcripts to make sure 
# that the class imbalance between coding and noncoding is not too severe. 
# (we only need to look at preds from one model, since each model is tested with all species).
hsap_preds = esm_trained_all_preds.loc[esm_trained_all_preds.training_species_id == 'hsap'].copy()
pd.merge(
    hsap_preds.groupby(['testing_species_id', 'true_label']).count().sequence_id,
    (
        hsap_preds.loc[hsap_preds.sequence_length < 100]
        .groupby(['testing_species_id', 'true_label'])
        .count()
        .sequence_id
    ),
    left_index=True,
    right_index=True,
    suffixes=('_all', '_short'),
)
Out[25]:
sequence_id_all sequence_id_short
testing_species_id true_label
amel coding 11725 223
noncoding 2429 1204
atha coding 30734 1800
noncoding 3638 3164
cele coding 19972 1285
noncoding 5138 5051
ddis coding 11688 1220
noncoding 28 28
dmel coding 15977 538
noncoding 2780 1617
drer coding 28385 2088
noncoding 3674 2343
ggal coding 24448 375
noncoding 19831 8242
hsap coding 55672 14928
noncoding 42157 23212
mmus coding 32128 8062
noncoding 19171 11011
oind coding 28134 2515
noncoding 205 200
rnor coding 18437 556
noncoding 4440 2840
scer coding 6013 368
noncoding 103 99
spom coding 4683 178
noncoding 1032 668
tthe coding 22767 19158
noncoding 11 11
xtro coding 27543 572
noncoding 244 243
zmay coding 39519 1002
noncoding 2673 744
In [26]:
def concat_blast_results(dirpaths):
    """
    Aggregate the blast results generated by `plmutils orf-classification blast-peptipedia`.
    """
    blast_results_columns = (
        "qseqid sseqid full_sseq pident length qlen slen mismatch gapopen qstart qend sstart send evalue bitscore"
    ).split(' ')

    dfs = []
    for dirpath in dirpaths:
        filepaths = pathlib.Path(dirpath).glob('*.tsv')
        for filepath in filepaths:
            try:
                # the blast output files have no header row.
                df = pd.read_csv(filepath, sep='\t', header=None)
            except Exception:
                # skip files that are empty or otherwise unparseable.
                continue
            df.columns = blast_results_columns
            dfs.append(df)
    return pd.concat(dfs)
In [27]:
blast_results = concat_blast_results(
    [
        '../output/data/processed/final/coding-dedup-ssx1/blast-peptipedia-results/',
        '../output/data/processed/final/noncoding-dedup-ssx1/blast-peptipedia-results/',
    ]
)
In [28]:
# use the log of the evalue for readability.
blast_results['evalue'] = np.log(blast_results.evalue)

# we only need to examine the minimum evalue for all hits to each peptide.
min_evalues = blast_results.groupby('qseqid').evalue.min().reset_index()
In [29]:
# merge the minimum evalues with the model predictions.
esm_trained_short_preds_w_evalues = pd.merge(
    esm_trained_short_preds, min_evalues, left_on='sequence_id', right_on='qseqid', how='inner'
)

esm_trained_all_preds_w_evalues = pd.merge(
    esm_trained_all_preds, min_evalues, left_on='sequence_id', right_on='qseqid', how='inner'
)
In [30]:
esm_trained_short_preds_w_evalues_short_only = esm_trained_short_preds_w_evalues.loc[
    esm_trained_short_preds_w_evalues.sequence_length < 100
].copy()
In [31]:
# sanity-check: count the number of peptides that had hits in peptipedia.
(
    esm_trained_short_preds_w_evalues_short_only
    # we only need to look at one model
    .loc[esm_trained_short_preds_w_evalues_short_only.training_species_id == 'hsap']
    .groupby(['testing_species_id', 'true_label'])
    .count()
    [['sequence_id']]
)
Out[31]:
sequence_id
testing_species_id true_label
amel coding 108
noncoding 10
atha coding 534
noncoding 17
cele coding 217
noncoding 4
ddis coding 366
dmel coding 121
noncoding 12
drer coding 273
noncoding 46
ggal coding 142
noncoding 5
hsap coding 2470
noncoding 1308
mmus coding 996
noncoding 119
oind coding 198
noncoding 1
rnor coding 257
noncoding 14
scer coding 339
noncoding 1
spom coding 144
noncoding 23
tthe coding 400
xtro coding 187
noncoding 1
zmay coding 113
noncoding 4

Histograms of evalues for coding and noncoding transcripts¶

This was to determine whether the false positives were enriched for peptides that had hits in peptipedia, which would suggest that they correspond to genuine sORFs from lncRNAs (and are therefore not actually false positives).
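A compact numeric complement to these histograms is the median minimum log e-value per prediction outcome among noncoding transcripts. This is a sketch assuming the column names used elsewhere in this notebook (true_label, predicted_probability, evalue):

```python
import numpy as np
import pandas as pd


def evalue_enrichment_summary(preds: pd.DataFrame) -> pd.Series:
    """Median minimum log e-value per prediction outcome among noncoding
    transcripts. Lower medians among false positives would be consistent
    with genuine sORFs hiding in the 'noncoding' class."""
    noncoding = preds.loc[preds.true_label == "noncoding"].copy()
    noncoding["outcome"] = np.where(
        noncoding.predicted_probability > 0.5, "false_positive", "true_negative"
    )
    return noncoding.groupby("outcome").evalue.median()
```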

In [32]:
# we only look at preds for short peptides from the human dataset,
# because it is one of the only datasets with a decent number of short
# peptides from noncoding transcripts that have peptipedia hits.
preds = esm_trained_all_preds_w_evalues.loc[
    (esm_trained_all_preds_w_evalues.training_species_id == 'hsap') &
    (esm_trained_all_preds_w_evalues.testing_species_id == 'hsap') &
    (esm_trained_all_preds_w_evalues.sequence_length < 100)
]

fig, axs = plt.subplots(1, 2, figsize=(16, 6))

min_min_evalue = -150
bins = np.arange(min_min_evalue, 0, -min_min_evalue/30)
kwargs = dict(bins=bins, density=False, alpha=0.5)

# left axis: coding transcripts
ax = axs[0]
ax.hist(
    preds[(preds.true_label == 'coding') & (preds.predicted_probability > 0.5)].evalue,
    label='True positives',
    color='blue',
    **kwargs
)
ax.hist(
    preds[(preds.true_label == 'coding') & (preds.predicted_probability < 0.5)].evalue,
    label='False negatives',
    color='red',
    **kwargs
)
ax.legend()
ax.set_xlabel('Minimum log evalue')
ax.set_ylabel('Count')
ax.set_title('Coding transcripts')

# right axis: noncoding transcripts
ax = axs[1]
ax.hist(
    preds[(preds.true_label == 'noncoding') & (preds.predicted_probability < 0.5)].evalue,
    label='True negatives',
    color='blue',
    **kwargs
)
_ = ax.hist(
    preds[(preds.true_label == 'noncoding') & (preds.predicted_probability > 0.5)].evalue,
    label='False positives',
    color='red',
    **kwargs
)
ax.legend()
ax.set_title('Noncoding transcripts')
Out[32]:
Text(0.5, 1.0, 'Noncoding transcripts')
In [ ]: